#include "ns3/boolean.h"
#include "ns3/command-line.h"
#include "ns3/config.h"
#include "ns3/double.h"
#include "ns3/eht-phy.h"
#include "ns3/enum.h"
#include "ns3/internet-stack-helper.h"
#include "ns3/ipv4-address-helper.h"
#include "ns3/log.h"
#include "ns3/mobility-helper.h"
#include "ns3/multi-model-spectrum-channel.h"
#include "ns3/on-off-helper.h"
#include "ns3/packet-sink-helper.h"
#include "ns3/packet-sink.h"
#include "ns3/rng-seed-manager.h"
#include "ns3/spectrum-wifi-helper.h"
#include "ns3/ssid.h"
#include "ns3/string.h"
#include "ns3/udp-client-server-helper.h"
#include "ns3/uinteger.h"
#include "ns3/wifi-acknowledgment.h"
#include "ns3/yans-wifi-channel.h"
#include "ns3/yans-wifi-helper.h"

#include <array>
#include <functional>
#include <map>
#include <numeric>
#include <optional>
#include <string>
#include <vector>

using namespace ns3;

/**
 * \param udp true if UDP is used, false if TCP is used
 * \param serverApp a container of server applications
 * \param payloadSize the size in bytes of the packets
 * \return the bytes received by each server application
 */
std::vector<uint64_t>
GetRxBytes(bool udp, const ApplicationContainer& serverApp, uint32_t payloadSize)
{
    std::vector<uint64_t> rxBytes(serverApp.GetN(), 0);
    if (udp)
    {
        for (uint32_t i = 0; i < serverApp.GetN(); i++)
        {
            rxBytes[i] = payloadSize * DynamicCast<UdpServer>(serverApp.Get(i))->GetReceived();
        }
    }
    else
    {
        for (uint32_t i = 0; i < serverApp.GetN(); i++)
        {
            rxBytes[i] = DynamicCast<PacketSink>(serverApp.Get(i))->GetTotalRx();
        }
    }
    return rxBytes;
}

int
main(int argc, char* argv[])
{
    bool useRts{false}; // use RTS/CTS
    if (useRts)
    {
        Config::SetDefault("ns3::WifiRemoteStationManager::RtsCtsThreshold", StringValue("0"));
        Config::SetDefault("ns3::WifiDefaultProtectionManager::EnableMuRts", BooleanValue(true));
    }

    std::string dlAckSeqType{"NO-OFDMA"};
    bool enableUlOfdma{true};
    bool enableBsrp{false};
    Time accessReqInterval{0};
    if (dlAckSeqType == "ACK-SU-FORMAT")
    {
        Config::SetDefault("ns3::WifiDefaultAckManager::DlMuAckSequenceType",
                           EnumValue(WifiAcknowledgment::DL_MU_BAR_BA_SEQUENCE));
    }
    else if (dlAckSeqType == "MU-BAR")
    {
        Config::SetDefault("ns3::WifiDefaultAckManager::DlMuAckSequenceType",
                           EnumValue(WifiAcknowledgment::DL_MU_TF_MU_BAR));
    }
    else if (dlAckSeqType == "AGGR-MU-BAR")
    {
        Config::SetDefault("ns3::WifiDefaultAckManager::DlMuAckSequenceType",
                           EnumValue(WifiAcknowledgment::DL_MU_AGGREGATE_TF));
    }
    else if (dlAckSeqType != "NO-OFDMA")
    {
        NS_ABORT_MSG("Invalid DL ack sequence type (must be NO-OFDMA, ACK-SU-FORMAT, MU-BAR or "
                     "AGGR-MU-BAR)");
    }

    WifiHelper wifi;
    wifi.SetStandard(WIFI_STANDARD_80211be); // 设置WiFi协议为802.11be

    int mcs{13}; // MCS index

    struct channelInfo
    {
        int channelNumber;
        int channelWidth;
        double channelFrequency;
        int primaryIndex;
    };

    std::vector<channelInfo> channels = {{36, 20, 5, 0},
                                         {40, 20, 5, 0}}; // channel info vector
    std::vector<std::string> channelStr;
    std::vector<FrequencyRange> freqRanges;
    std::string dataModeStr = "EhtMcs" + std::to_string(mcs);
    std::string ctrlRateStr;
    uint64_t nonHtRefRateMbps = EhtPhy::GetNonHtReferenceRate(mcs) / 1e6;
    int nLinks = 0;

    std::vector<SpectrumWifiPhyHelper> singlePhys;
    SpectrumWifiPhyHelper phy(channels.size());
    phy.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
    phy.Set("ChannelSwitchDelay", TimeValue(MicroSeconds(100)));

    for (auto channel : channels)
    {
        channelStr.resize(nLinks + 1);
        freqRanges.resize(nLinks + 1);

        double freq = channel.channelFrequency;
        int channelWidth = channel.channelWidth;
        channelStr[nLinks] = "{" + std::to_string(channel.channelNumber) + ", " +
                             std::to_string(channelWidth) + ", "; // 信道号不要忘记设置啊喂(#`O′)
        if (freq == 6)
        {
            channelStr[nLinks] += "BAND_6GHZ, 0}";
            freqRanges[nLinks] = WIFI_SPECTRUM_6_GHZ;
            Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss",
                               DoubleValue(48));
            wifi.SetRemoteStationManager(nLinks,
                                         "ns3::ConstantRateWifiManager",
                                         "DataMode",
                                         StringValue(dataModeStr),
                                         "ControlMode",
                                         StringValue(dataModeStr));
        }
        else if (freq == 5)
        {
            channelStr[nLinks] += "BAND_5GHZ, 0}";
            freqRanges[nLinks] = WIFI_SPECTRUM_5_GHZ;
            ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
            wifi.SetRemoteStationManager(nLinks,
                                         "ns3::ConstantRateWifiManager",
                                         "DataMode",
                                         StringValue(dataModeStr),
                                         "ControlMode",
                                         StringValue(ctrlRateStr));
        }
        else if (freq == 2.4)
        {
            channelStr[nLinks] += "BAND_2_4GHZ, 0}";
            freqRanges[nLinks] = WIFI_SPECTRUM_2_4_GHZ;
            Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss",
                               DoubleValue(40));
            ctrlRateStr = "ErpOfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
            wifi.SetRemoteStationManager(nLinks,
                                         "ns3::ConstantRateWifiManager",
                                         "DataMode",
                                         StringValue(dataModeStr),
                                         "ControlMode",
                                         StringValue(ctrlRateStr));
        }

        /* 创建多个 Channel */
        auto spectrumChannel = CreateObject<MultiModelSpectrumChannel>();
        auto lossModel = CreateObject<LogDistancePropagationLossModel>();
        spectrumChannel->AddPropagationLossModel(lossModel);
        SpectrumWifiPhyHelper tempPhy(1);
        tempPhy.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
        tempPhy.SetChannel(spectrumChannel);
        tempPhy.Set(0, "ChannelSettings", StringValue(channelStr[nLinks]));

        singlePhys.push_back(tempPhy);

        phy.SetChannel(spectrumChannel);
        phy.Set(nLinks, "ChannelSettings", StringValue(channelStr[nLinks]));

        nLinks++;
    }

    /**
     * @brief sta的数量，取值为channel的数量加一是为了在每个channel上都要有sta和ap连接，
     *
     * 还剩一个在所有Link上建立连接
     */
    std::size_t nStations = nLinks + 1; // 先暂时测试一下MLD能否正常工作
    NodeContainer wifiStaNodes;
    wifiStaNodes.Create(nStations);
    NodeContainer wifiApNode;
    wifiApNode.Create(1);
    Ssid ssid = Ssid("ns3-80211be"); // wifi网络的名称

    // 设置mac层
    WifiMacHelper mac;
    wifi.ConfigEhtOptions("EmlsrActivated", BooleanValue(true)); // 打开WiFi对emlsr的支持
    // 设置mac层，供sta使用，并且启动emlsr
    mac.SetType("ns3::StaWifiMac", "Ssid", SsidValue(ssid));
    NetDeviceContainer staDevices;

    for (size_t i = 1; i < nStations; i++)
    {
        /* STAi STAi 与 AP 在第 i - 1 条Link上建立连
         * 考虑到在help函数中也是creat类之后再安装到device中，应该在我看来，各个device的phy的channel应该没有血缘关系
         * 在wifihelper的install方法中，会检查WiFi的Link数量和phyhelper的数量是否一致，因此这里还需要重写一遍wifihelper
         * 的配置
         */

        int linkId = i - 1;
        double freq = channels[linkId].channelFrequency;

        WifiHelper tempWifiHelper;
        tempWifiHelper.SetStandard(WIFI_STANDARD_80211be);
        if (freq == 6)
        {
            Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss",
                               DoubleValue(48));
            tempWifiHelper.SetRemoteStationManager((int)0,
                                                   "ns3::ConstantRateWifiManager",
                                                   "DataMode",
                                                   StringValue(dataModeStr),
                                                   "ControlMode",
                                                   StringValue(dataModeStr));
        }
        else if (freq == 5)
        {
            ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
            tempWifiHelper.SetRemoteStationManager((int)0,
                                                   "ns3::ConstantRateWifiManager",
                                                   "DataMode",
                                                   StringValue(dataModeStr),
                                                   "ControlMode",
                                                   StringValue(ctrlRateStr));
        }
        else if (freq == 2.4)
        {
            Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss",
                               DoubleValue(40));
            ctrlRateStr = "ErpOfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
            tempWifiHelper.SetRemoteStationManager((int)0,
                                                   "ns3::ConstantRateWifiManager",
                                                   "DataMode",
                                                   StringValue(dataModeStr),
                                                   "ControlMode",
                                                   StringValue(ctrlRateStr));
        }

        staDevices.Add(tempWifiHelper.Install(singlePhys[linkId], mac, wifiStaNodes.Get(i)));
    }

    std::string emlsrLinks = "0,1";
    uint32_t paddingDelayUsec{128};    // emlsr的padding延迟
    uint32_t transitionDelayUsec{128}; // emlsr的transition延迟
    bool switchAuxPhy{true};           // 是否切换到辅助信道
    mac.SetEmlsrManager("ns3::DefaultEmlsrManager",
                        "EmlsrLinkSet",
                        StringValue(emlsrLinks),
                        "EmlsrPaddingDelay",
                        TimeValue(MicroSeconds(paddingDelayUsec)),
                        "EmlsrTransitionDelay",
                        TimeValue(MicroSeconds(transitionDelayUsec)),
                        "SwitchAuxPhy",
                        BooleanValue(switchAuxPhy));
    // STA0 STA0 与 AP 在所有的Link上都建立连接 
    staDevices.Add(wifi.Install(phy, mac, wifiStaNodes.Get(0)));

    // AP设置
    if (dlAckSeqType != "NO-OFDMA")
    {
        mac.SetMultiUserScheduler("ns3::RrMultiUserScheduler",
                                  "EnableUlOfdma",
                                  BooleanValue(enableUlOfdma),
                                  "EnableBsrp",
                                  BooleanValue(enableBsrp),
                                  "AccessReqInterval",
                                  TimeValue(accessReqInterval));
    }
    mac.SetType("ns3::ApWifiMac",
                "EnableBeaconJitter",
                BooleanValue(false),
                "Ssid",
                SsidValue(ssid));
    NetDeviceContainer apDevice = wifi.Install(phy, mac, wifiApNode);

    // 下面是更高层的一些调整

    // 随机数的设置
    RngSeedManager::SetSeed(1);
    RngSeedManager::SetRun(1);
    int64_t streamNumber = 100;
    streamNumber += wifi.AssignStreams(apDevice, streamNumber);
    streamNumber += wifi.AssignStreams(staDevices, streamNumber);

    // Set guard interval and MPDU buffer size
    int gi = 3200; // 符号间的间隔
    bool useExtendedBlockAck{false};
    Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/HeConfiguration/GuardInterval",
                TimeValue(NanoSeconds(gi)));
    Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/HeConfiguration/MpduBufferSize",
                UintegerValue(useExtendedBlockAck ? 256 : 64));

    // mobility.位置模型
    MobilityHelper mobility;
    Ptr<ListPositionAllocator> positionAlloc = CreateObject<ListPositionAllocator>();
    double distance = 0.5; // 结点间的距离，米
    for (std::size_t i = 0; i < nStations + 1; i++)
    {
        positionAlloc->Add(Vector(distance * i, 0.0, 0.0));
    }
    mobility.SetMobilityModel("ns3::ConstantPositionMobilityModel");
    mobility.Install(wifiApNode);
    mobility.Install(wifiStaNodes);

    /* Internet stack*/
    InternetStackHelper stack;
    stack.Install(wifiApNode);
    stack.Install(wifiStaNodes);

    Ipv4AddressHelper address;
    address.SetBase("192.168.1.0", "255.255.255.0");
    Ipv4InterfaceContainer staNodeInterfaces;
    Ipv4InterfaceContainer apNodeInterface;

    staNodeInterfaces = address.Assign(staDevices);
    apNodeInterface = address.Assign(apDevice);

    /* Setting applications */
    ApplicationContainer serverApp;

    /* STA向AP发送大量UDP数据 */

    double simulationTime = 10.0; // seconds
    uint16_t portAp = 9;
    UdpServerHelper serverAp(portAp);
    serverApp = serverAp.Install(wifiApNode);
    serverApp.Start(Seconds(0.0));
    serverApp.Stop(Seconds(simulationTime + 1));
    for (std::size_t i = 0; i < nStations; i++)
    {
        UdpClientHelper clientSta(apNodeInterface.GetAddress(0), portAp);
        clientSta.SetAttribute("MaxPackets", UintegerValue(4294967295U));
        clientSta.SetAttribute("Interval", TimeValue(Time("0.00001"))); // packets/s
        clientSta.SetAttribute("PacketSize", UintegerValue(700));
        ApplicationContainer clientApp;
        clientApp = clientSta.Install(wifiStaNodes.Get(i));
        clientApp.Start(Seconds(1.0));
        clientApp.Stop(Seconds(simulationTime + 1));
    }

    // 查看server接收数据包的速度
    std::cout << "MCS value"
              << "\t\t"
              << "Channel width"
              << "\t\t"
              << "GI"
              << "\t\t\t"
              << "Throughput" << '\n';

    Mac48Address apMac = Mac48Address::ConvertFrom(wifiApNode.Get(0)->GetDevice(0)->GetAddress());
    std::cout << "AP MAC address: " << apMac << std::endl;
    std::string pcapFileName = "MultiSta-ap";
    phy.EnablePcap(pcapFileName, apDevice, true);

    for (std::size_t i = 0; i < nStations; i++)
    {
        Mac48Address staMac =
            Mac48Address::ConvertFrom(wifiStaNodes.Get(i)->GetDevice(0)->GetAddress());
        std::cout << "STA" << i << " MAC address: " << staMac << std::endl;
        pcapFileName = "MultiSta-sta" + std::to_string(i);
        phy.EnablePcap(pcapFileName, staDevices.Get(i), true);
    }

    Simulator::Stop(Seconds(simulationTime + 1));
    Simulator::Run();

    std::vector<uint64_t> cumulRxBytes(1, 0);
    cumulRxBytes = GetRxBytes(true, serverApp, 700);
    uint64_t rxBytes =
        std::accumulate(cumulRxBytes.cbegin(), cumulRxBytes.cend(), 0); // 向量中元素累加
    double throughput = (rxBytes * 8) / (simulationTime * 1000000.0);   // Mbit/s

    Simulator::Destroy();
    std::cout << mcs << "\t\t\t"
              << "20 MHz\t\t\t" << gi << " ns\t\t\t" << throughput << " Mbit/s" << std::endl;

    return 0;
}